{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4e5eec8e",
   "metadata": {
    "tags": []
   },
   "source": [
    "# 图像曝光增强\n",
    "功能介绍：使用模型对曝光不足的输入图片进行HDR效果增强。    \n",
    "样例输入：png图像。     \n",
    "样例输出：增强后png图像。  \n",
    "## 前期准备  \n",
    "基础镜像的样例目录中已包含转换后的om模型以及测试图片，如果直接运行，可跳过此步骤。如果需要重新转换模型，可参考如下步骤：\n",
    "1. 获取模型和测试数据：我们可以在[这个链接](https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Atlas%20200I%20DK%20A2/DevKit/downloads/23.0.RC1/Ascend-devkit_23.0.RC1_downloads.xlsx)的表格中找到本样例的依赖文件，将模型和测试数据下载到本地，将image_HDR_enhance.pb放入model目录，将测试图片放入data目录。\n",
    "\n",
    "2. 模型转换：进入model目录，执行：  \n",
    "    ```shell\n",
    "    atc --model=./image_HDR_enhance.pb --framework=3 --output=image_HDR_enhance --soc_version=Ascend310B1 --input_shape=\"input:1,512,512,3\" --input_format=NHWC --output_type=FP32\n",
    "    ```\n",
    "\n",
    "    其中各个参数具体含义如下：  \n",
    "    * --model：原始模型文件。\n",
    "    * --framework：原始框架类型，0:Caffe; 1:MindSpore; 2:Tensorflow; 5:Onnx。\n",
    "    * --output：离线推理om模型文件路径。\n",
    "    * --soc_version：昇腾AI处理器型号，填写\"Ascend310B1\"。\n",
    "    * --input_shape：模型输入节点名称和shape。\n",
    "    * --input_format：输入Tensor的内存排列方式。\n",
    "    * --output_type：指定网络输出数据类型或指定某个输出节点的输出类型。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "94767da9",
   "metadata": {},
   "source": [
    "## 模型推理实现  \n",
    "1.导入需要的第三方库以及AscendCL推理所需文件。 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "220f9b53",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import cv2\n",
    "import numpy as np\n",
    "import os\n",
    "import time\n",
    "import matplotlib.pyplot as plt\n",
    "import acl\n",
    "\n",
    "import acllite_utils as utils\n",
    "import constants as constants\n",
    "from acllite_model import AclLiteModel\n",
    "from acllite_resource import resource_list"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3fd7f142",
   "metadata": {},
   "source": [
    "2.定义资源管理类，包括初始化acl、释放acl资源功能。  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a69ea811",
   "metadata": {},
   "outputs": [],
   "source": [
    "class AclLiteResource:\n",
    "    \"\"\"\n",
    "    AclLiteResource类\n",
    "    \"\"\"\n",
    "    def __init__(self, device_id=0):\n",
    "        self.device_id = device_id\n",
    "        self.context = None\n",
    "        self.stream = None\n",
    "        self.run_mode = None\n",
    "        \n",
    "    def init(self):\n",
    "        \"\"\"\n",
    "        初始化资源\n",
    "        \"\"\"\n",
    "        print(\"init resource stage:\")\n",
    "        ret = acl.init() # acl初始化\n",
    "\n",
    "        ret = acl.rt.set_device(self.device_id) # 指定运算的device\n",
    "        utils.check_ret(\"acl.rt.set_device\", ret)\n",
    "\n",
    "        self.context, ret = acl.rt.create_context(self.device_id) # 创建context\n",
    "        utils.check_ret(\"acl.rt.create_context\", ret)\n",
    "\n",
    "        self.stream, ret = acl.rt.create_stream() # 创建stream\n",
    "        utils.check_ret(\"acl.rt.create_stream\", ret)\n",
    "\n",
    "        self.run_mode, ret = acl.rt.get_run_mode() # 获取运行模式\n",
    "        utils.check_ret(\"acl.rt.get_run_mode\", ret)\n",
    "\n",
    "        print(\"Init resource success\")\n",
    "\n",
    "    def __del__(self):\n",
    "        print(\"acl resource release all resource\")\n",
    "        resource_list.destroy()\n",
    "        if self.stream:\n",
    "            print(\"acl resource release stream\")\n",
    "            acl.rt.destroy_stream(self.stream) # 销毁stream\n",
    "\n",
    "        if self.context:\n",
    "            print(\"acl resource release context\")\n",
    "            acl.rt.destroy_context(self.context) # 释放context\n",
    "\n",
    "        print(\"Reset acl device \", self.device_id)\n",
    "        acl.rt.reset_device(self.device_id) # 释放device\n",
    "        \n",
    "        print(\"Release acl resource success\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ee64aff",
   "metadata": {},
   "source": [
    "3.实现具体推理功能，包含预处理、推理、后处理等功能。  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08a8c26b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from download import download\n",
    "\n",
    "# 获取模型om文件\n",
    "model_url = \"https://modelers.cn/coderepo/web/v1/file/MindSpore-Lab/cluoud_obs/main/media/examples/mindspore-courses/orange-pi-mindspore/04-HDR/image_HDR_enhance.zip\"\n",
    "download(model_url, \"./model\", kind=\"zip\", replace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "954d92f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "path = os.getcwd()\n",
    "input_w = 512   # 推理输入width\n",
    "input_h = 512   # 推理输入height\n",
    "INPUT_DIR = os.path.join(path, 'data/') # 输入路径\n",
    "OUTPUT_DIR = os.path.join(path, 'out/') # 输出路径\n",
    "\n",
    "def pre_process(dir_name, input_h, input_w):\n",
    "    \"\"\"\n",
    "    预处理\n",
    "    \"\"\"\n",
    "    BGR = cv2.imread(dir_name).astype(np.float32)\n",
    "    h = BGR.shape[0]\n",
    "    w = BGR.shape[1]\n",
    "    # 进行归一化、缩放、颜色转换\n",
    "    BGR = BGR / 255.0\n",
    "    BGR = cv2.resize(BGR, (input_h, input_w))\n",
    "    RGB = cv2.cvtColor(BGR, cv2.COLOR_BGR2RGB)\n",
    "    return RGB, h, w\n",
    "\n",
    "def post_process(input_img, result_list, pic, input_h, input_w):\n",
    "    \"\"\"\n",
    "    后处理\n",
    "    \"\"\"\n",
    "    o_w, o_h = input_img.shape[:2]\n",
    "    # 获取推理结果，进行形状变换\n",
    "    data = result_list[0].reshape(input_h, input_w, 3)\n",
    "    # 进行缩放、颜色转换\n",
    "    output = (cv2.resize(data, (o_w, o_h)) * 255.0).astype(np.uint8)\n",
    "    output_img = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)\n",
    "    # 保存增强后图像\n",
    "    file_name = os.path.join(OUTPUT_DIR, pic)\n",
    "    cv2.imwrite(file_name, output_img)\n",
    "    # 拼接输入图像和增强后图像，返回进行显示\n",
    "    BGR_U8 = np.concatenate([input_img, output_img], axis=1)\n",
    "    return BGR_U8\n",
    "\n",
    "def main():\n",
    "    # 创建推理结果存放路径\n",
    "    if not os.path.exists(OUTPUT_DIR):\n",
    "        os.mkdir(OUTPUT_DIR)\n",
    "    # acl初始化\n",
    "    acl_resource = AclLiteResource()\n",
    "    acl_resource.init()\n",
    "    # 加载模型\n",
    "    model_path = os.path.join(path, \"model/image_HDR_enhance.om\")\n",
    "    model = AclLiteModel(model_path)\n",
    "    # 遍历数据集进行推理\n",
    "    src_dir = os.listdir(INPUT_DIR)\n",
    "    for pic in src_dir:\n",
    "        if not pic.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):\n",
    "            print('it is not a picture, %s, ignore this file and continue,' % pic)\n",
    "            continue\n",
    "        pic_path = os.path.join(INPUT_DIR, pic)\n",
    "        input_img = cv2.imread(pic_path)\n",
    "        # 进行预处理\n",
    "        RGB_image, o_h, o_w = pre_process(pic_path, input_h, input_w)\n",
    "        # 计算推理耗时\n",
    "        start_time = time.time()\n",
    "        # 执行推理\n",
    "        result_list = model.execute([RGB_image, ])\n",
    "        end_time = time.time()\n",
    "        # 打印推理的图片信息和耗时\n",
    "        print('pic:{}'.format(pic))\n",
    "        print('pic_size:{}x{}'.format(o_h, o_w))\n",
    "        print('time:{}ms'.format(int((end_time - start_time) * 1000)))\n",
    "        print('\\n')\n",
    "        # 进行后处理\n",
    "        img_result = post_process(input_img, result_list, pic, input_h, input_w)      \n",
    "        # 显示输入图像和增强后图像\n",
    "        img_RGB = img_result[:, :, [2, 1, 0]] # RGB\n",
    "        plt.axis('off')\n",
    "        plt.xticks([])\n",
    "        plt.yticks([])\n",
    "        plt.imshow(img_RGB)\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "098e629c",
   "metadata": {},
   "source": [
    "4.运行样例    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bcd7377",
   "metadata": {},
   "outputs": [],
   "source": [
    "main()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "03aed685",
   "metadata": {},
   "source": [
    "5.运行完成上述样例后，打印出测试图像名称和尺寸、推理时间、资源初始化和释放信息。并展示了输入图像和增强后图像，以及将增强后图像保存在out目录下，可进行查看。  "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc3b6ba7",
   "metadata": {},
   "source": [
    "## 样例总结\n",
    "样例实现流程包括几个关键步骤：  \n",
    "1.初始化acl资源：在调用acl相关资源时，必须先初始化AscendCL，否则可能会导致后续系统内部资源初始化出错。  \n",
    "2.对输入进行预处理：包括图像归一化、缩放、颜色转换操作。  \n",
    "3.推理：利用AclLiteModel.execute接口进行推理。    \n",
    "4.对推理结果进行后处理：包括形状变换、缩放、颜色转换操作，保存增强后图像。  \n",
    "5.可视化图片：利用plt将结果画出。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  },
  "vscode": {
   "interpreter": {
    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
